{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "67cc7844",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/run-llama/llama_index/blob/main/docs/examples/finetuning/embeddings/finetune_embedding_adapter.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03f5ac7e-d36d-4879-959a-1af414fe4c02",
   "metadata": {},
   "source": [
    "# Fine-tuning an Adapter on Top of Any Black-Box Embedding Model\n",
    "\n",
    "\n",
    "LlamaIndex lets you fine-tune an adapter on top of embeddings produced by any model (sentence_transformers, OpenAI, and more).\n",
    "\n",
    "Fine-tuning an adapter transforms your embedding representations into a new latent space optimized for retrieval over your specific data and queries. This can yield modest gains in retrieval performance, which in turn can translate into better-performing RAG systems.\n",
    "\n",
    "We do this via our `EmbeddingAdapterFinetuneEngine` abstraction. We fine-tune three types of adapters:\n",
    "- Linear\n",
    "- 2-Layer NN\n",
    "- Custom NN"
   ]
  },
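  {
   "cell_type": "markdown",
   "id": "sketch-linear-adapter-md",
   "metadata": {},
   "source": [
    "As rough intuition for what a linear adapter does, here is a minimal pure-Python sketch (not LlamaIndex code). It assumes the adapter is a learned square matrix `W` (with an optional bias `b`) applied to the query embedding only, so document embeddings from the frozen base model never need to be recomputed. `apply_linear_adapter` and `cosine` are hypothetical helpers, and the 3-dimensional vectors and weights are made up for illustration.\n",
    "\n",
    "In the toy example below the adapter pulls the query towards the target document in cosine terms. Fine-tuning learns `W` from the (question, context) pairs rather than setting it by hand."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-linear-adapter-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "from typing import List, Optional\n",
    "\n",
    "Vector = List[float]\n",
    "Matrix = List[List[float]]\n",
    "\n",
    "\n",
    "def apply_linear_adapter(query_emb: Vector, W: Matrix, b: Optional[Vector] = None) -> Vector:\n",
    "    # adapted query = W @ query_emb (+ b); document embeddings are left untouched\n",
    "    out = [sum(w_ij * x_j for w_ij, x_j in zip(row, query_emb)) for row in W]\n",
    "    if b is not None:\n",
    "        out = [o + b_i for o, b_i in zip(out, b)]\n",
    "    return out\n",
    "\n",
    "\n",
    "def cosine(u: Vector, v: Vector) -> float:\n",
    "    dot = sum(a * c for a, c in zip(u, v))\n",
    "    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(c * c for c in v)))\n",
    "\n",
    "\n",
    "# toy 3-d example with made-up numbers\n",
    "doc_emb = [0.0, 1.0, 0.0]  # fixed base-model document embedding\n",
    "query_emb = [1.0, 0.2, 0.0]  # base-model query embedding\n",
    "W = [\n",
    "    [0.1, 0.0, 0.0],\n",
    "    [0.9, 1.0, 0.0],\n",
    "    [0.0, 0.0, 1.0],\n",
    "]\n",
    "\n",
    "adapted_query = apply_linear_adapter(query_emb, W)\n",
    "print('cosine before adapter:', round(cosine(query_emb, doc_emb), 3))\n",
    "print('cosine after adapter:', round(cosine(adapted_query, doc_emb), 3))"
   ]
  },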
  {
   "cell_type": "markdown",
   "id": "9ab6c5cc-8b31-41cd-95aa-6d60fbefff9b",
   "metadata": {},
   "source": [
    "## Generate Corpus\n",
    "\n",
    "We use our helper function, `generate_qa_embedding_pairs`, to generate our training and evaluation datasets. This function takes in any set of text nodes (chunks) and generates a structured dataset containing (question, context) pairs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35c49d8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install llama-index-embeddings-openai\n",
    "%pip install llama-index-embeddings-adapter\n",
    "%pip install llama-index-finetuning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b36f73f-83b1-4715-bd4d-7ce1353d1a19",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "from llama_index.core import SimpleDirectoryReader\n",
    "from llama_index.core.node_parser import SentenceSplitter\n",
    "from llama_index.core.schema import MetadataMode"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "2fc4bd24",
   "metadata": {},
   "source": [
    "Download Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ae97522",
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p 'data/10k/'\n",
    "!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/uber_2021.pdf' -O 'data/10k/uber_2021.pdf'\n",
    "!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/lyft_2021.pdf' -O 'data/10k/lyft_2021.pdf'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58c43042-2ed1-4ab7-a53d-7f65dd856f83",
   "metadata": {},
   "outputs": [],
   "source": [
    "TRAIN_FILES = [\"./data/10k/lyft_2021.pdf\"]\n",
    "VAL_FILES = [\"./data/10k/uber_2021.pdf\"]\n",
    "\n",
    "TRAIN_CORPUS_FPATH = \"./data/train_corpus.json\"\n",
    "VAL_CORPUS_FPATH = \"./data/val_corpus.json\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c7e38d0-39ff-44e2-ab7f-fded56dcd707",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_corpus(files, verbose=False):\n",
    "    if verbose:\n",
    "        print(f\"Loading files {files}\")\n",
    "\n",
    "    reader = SimpleDirectoryReader(input_files=files)\n",
    "    docs = reader.load_data()\n",
    "    if verbose:\n",
    "        print(f\"Loaded {len(docs)} docs\")\n",
    "\n",
    "    parser = SentenceSplitter()\n",
    "    nodes = parser.get_nodes_from_documents(docs, show_progress=verbose)\n",
    "\n",
    "    if verbose:\n",
    "        print(f\"Parsed {len(nodes)} nodes\")\n",
    "\n",
    "    return nodes"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d1257dce-0be1-42c4-9346-a1fe68505fdd",
   "metadata": {},
   "source": [
    "We do a very naive train/val split: the Lyft corpus is the training set, and the Uber corpus is the validation set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffd6d8af-5382-48b8-8a7d-98a03d2f150d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading files ['../../../examples/data/10k/lyft_2021.pdf']\n",
      "Loaded 238 docs\n"
     ]
    },
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.0050508975982666016,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": 37,
       "postfix": null,
       "prefix": "Parsing documents into nodes",
       "rate": null,
       "total": 238,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "54a44179a71f48ebb4481ccfa8a857a4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Parsing documents into nodes:   0%|          | 0/238 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parsed 349 nodes\n",
      "Loading files ['../../../examples/data/10k/uber_2021.pdf']\n",
      "Loaded 307 docs\n"
     ]
    },
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.0017158985137939453,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": 37,
       "postfix": null,
       "prefix": "Parsing documents into nodes",
       "rate": null,
       "total": 307,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "869943c416f948c5be35c86ebc5091e5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Parsing documents into nodes:   0%|          | 0/307 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parsed 418 nodes\n"
     ]
    }
   ],
   "source": [
    "train_nodes = load_corpus(TRAIN_FILES, verbose=True)\n",
    "val_nodes = load_corpus(VAL_FILES, verbose=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1893a5f1-6fdf-473b-80ea-5ea3df5681a7",
   "metadata": {},
   "source": [
    "### Generate synthetic queries\n",
    "\n",
    "Now, we use an LLM (gpt-3.5-turbo) to generate questions using each text chunk in the corpus as context.\n",
    "\n",
    "Each pair of (generated question, text chunk used as context) becomes a datapoint in the finetuning dataset (either for training or evaluation)."
   ]
  },
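  {
   "cell_type": "markdown",
   "id": "sketch-qa-dataset-md",
   "metadata": {},
   "source": [
    "Conceptually, the generated dataset consists of three mappings. The pure-Python sketch below assumes the layout we understand `EmbeddingQAFinetuneDataset` to use: `queries` (query id → question text), `corpus` (node id → chunk text) and `relevant_docs` (query id → list of relevant node ids). `build_qa_dataset` and `fake_question_generator` are hypothetical. The latter stands in for the LLM call so the sketch runs offline, and the chunk texts are placeholders."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-qa-dataset-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import uuid\n",
    "from typing import Callable, Dict, List\n",
    "\n",
    "\n",
    "def build_qa_dataset(\n",
    "    chunks: Dict[str, str],\n",
    "    generate_questions: Callable[[str, int], List[str]],\n",
    "    num_questions_per_chunk: int = 2,\n",
    ") -> Dict[str, Dict]:\n",
    "    queries: Dict[str, str] = {}\n",
    "    relevant_docs: Dict[str, List[str]] = {}\n",
    "    for node_id, text in chunks.items():\n",
    "        for question in generate_questions(text, num_questions_per_chunk):\n",
    "            query_id = str(uuid.uuid4())\n",
    "            queries[query_id] = question\n",
    "            # the only positive for each question is the chunk it was generated from\n",
    "            relevant_docs[query_id] = [node_id]\n",
    "    return {'queries': queries, 'corpus': dict(chunks), 'relevant_docs': relevant_docs}\n",
    "\n",
    "\n",
    "def fake_question_generator(text: str, n: int) -> List[str]:\n",
    "    # hypothetical stand-in for prompting an LLM to write n questions about `text`\n",
    "    return ['Question %d about: %s' % (i + 1, text) for i in range(n)]\n",
    "\n",
    "\n",
    "chunks = {\n",
    "    'node-1': 'Example chunk A from a 10-K filing.',\n",
    "    'node-2': 'Example chunk B from a 10-K filing.',\n",
    "}\n",
    "dataset = build_qa_dataset(chunks, fake_question_generator)\n",
    "print(len(dataset['queries']), 'queries over', len(dataset['corpus']), 'chunks')"
   ]
  },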
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee1c892e-e27d-49f6-96d4-b99af330aed8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.finetuning import generate_qa_embedding_pairs\n",
    "from llama_index.core.evaluation import EmbeddingQAFinetuneDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7330fb1f-cfb4-4b9b-b614-06910d5330b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = generate_qa_embedding_pairs(train_nodes)\n",
    "val_dataset = generate_qa_embedding_pairs(val_nodes)\n",
    "\n",
    "train_dataset.save_json(\"train_dataset.json\")\n",
    "val_dataset.save_json(\"val_dataset.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "909ca757-bf02-4304-a59e-7d61a12a67df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# [Optional] Load\n",
    "train_dataset = EmbeddingQAFinetuneDataset.from_json(\"train_dataset.json\")\n",
    "val_dataset = EmbeddingQAFinetuneDataset.from_json(\"val_dataset.json\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b619e9a6-4795-4ff5-bb48-ae2c50324eb2",
   "metadata": {},
   "source": [
    "## Run Embedding Finetuning\n",
    "\n",
    "We then fine-tune a linear adapter on top of an existing embedding model. We import our `EmbeddingAdapterFinetuneEngine` abstraction, which takes in the training dataset, an existing embedding model, and a set of training parameters."
   ]
  },
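  {
   "cell_type": "markdown",
   "id": "sketch-contrastive-loss-md",
   "metadata": {},
   "source": [
    "Training needs a contrastive objective over the (question, context) pairs. Our understanding is that the engine defaults to a multiple-negatives ranking loss. Under that loss, each query's own context is the positive and the other contexts in the same batch act as negatives. Treat this as an assumption. The plain-Python sketch below illustrates that style of loss and is not the library's implementation. `in_batch_negatives_loss` is hypothetical, and `scale=20.0` is an assumed temperature-like factor. When training an adapter, `queries` would be adapter outputs while `docs` stay fixed base-model embeddings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-contrastive-loss-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "from typing import List\n",
    "\n",
    "Vector = List[float]\n",
    "\n",
    "\n",
    "def cosine(u: Vector, v: Vector) -> float:\n",
    "    dot = sum(a * b for a, b in zip(u, v))\n",
    "    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))\n",
    "    return dot / norm\n",
    "\n",
    "\n",
    "def in_batch_negatives_loss(queries: List[Vector], docs: List[Vector], scale: float = 20.0) -> float:\n",
    "    # queries[i] is paired with docs[i]; every other docs[j] in the batch is a negative.\n",
    "    # Returns the mean cross-entropy of picking the right doc from scaled cosine scores.\n",
    "    total = 0.0\n",
    "    for i, q in enumerate(queries):\n",
    "        logits = [scale * cosine(q, d) for d in docs]\n",
    "        m = max(logits)  # subtract the max before exp() for numerical stability\n",
    "        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))\n",
    "        total += log_sum_exp - logits[i]\n",
    "    return total / len(queries)\n",
    "\n",
    "\n",
    "docs = [[1.0, 0.0], [0.0, 1.0]]\n",
    "aligned = [[0.9, 0.1], [0.1, 0.9]]  # each query points at its own doc\n",
    "swapped = [[0.1, 0.9], [0.9, 0.1]]  # each query points at the other doc\n",
    "print('aligned loss:', round(in_batch_negatives_loss(aligned, docs), 4))\n",
    "print('swapped loss:', round(in_batch_negatives_loss(swapped, docs), 4))"
   ]
  },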
  {
   "cell_type": "markdown",
   "id": "ade12658-40d9-4108-abd5-c978542001d3",
   "metadata": {},
   "source": [
    "### Fine-tune a Linear Adapter on bge-small-en (default)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7574d8b-e287-4cc5-9e8c-643c365755a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.finetuning import EmbeddingAdapterFinetuneEngine\n",
    "from llama_index.core.embeddings import resolve_embed_model\n",
    "import torch\n",
    "\n",
    "base_embed_model = resolve_embed_model(\"local:BAAI/bge-small-en\")\n",
    "\n",
    "finetune_engine = EmbeddingAdapterFinetuneEngine(\n",
    "    train_dataset,\n",
    "    base_embed_model,\n",
    "    model_output_path=\"model_output_test\",\n",
    "    # bias=True,\n",
    "    epochs=4,\n",
    "    verbose=True,\n",
    "    # optimizer_class=torch.optim.SGD,\n",
    "    # optimizer_params={\"lr\": 0.01}\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aeb22015-e4b7-44ae-b3b1-63cf7797d457",
   "metadata": {},
   "outputs": [],
   "source": [
    "finetune_engine.finetune()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63c96d50-110d-43de-acf2-8df8a7153ae9",
   "metadata": {},
   "outputs": [],
   "source": [
    "embed_model = finetune_engine.get_finetuned_model()\n",
    "\n",
    "# alternatively, load the saved adapter from the output path\n",
    "from llama_index.embeddings.adapter import LinearAdapterEmbeddingModel\n",
    "\n",
    "# embed_model = LinearAdapterEmbeddingModel(base_embed_model, \"model_output_test\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20d9861b-a41e-4f3c-8374-0d7e2a46103d",
   "metadata": {},
   "source": [
    "## Evaluate Finetuned Model\n",
    "\n",
    "We compare the fine-tuned model against the base model, as well as against text-embedding-ada-002.\n",
    "\n",
    "We evaluate with two ranking metrics:\n",
    "- **Hit-rate metric**: For each (query, context) pair, we retrieve the top-k documents with the query. It's a hit if the results contain the ground-truth context.\n",
    "- **Mean Reciprocal Rank (MRR)**: A slightly more granular ranking metric that looks at the \"reciprocal rank\" of the ground-truth context in the top-k retrieved set. The reciprocal rank is defined as 1/rank, and it is 0 if the results don't contain the context. MRR is the mean of the reciprocal ranks over all queries."
   ]
  },
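  {
   "cell_type": "markdown",
   "id": "sketch-hit-rate-mrr-md",
   "metadata": {},
   "source": [
    "Both metrics are easy to compute by hand. Here is a minimal sketch that assumes you already have, for each query, the ranked list of retrieved node ids and the single ground-truth node id. `hit_rate_and_mrr` is a hypothetical helper (the notebook's own `evaluate` comes from `eval_utils`), and its `top_k=5` default is an arbitrary choice."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-hit-rate-mrr-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict, List\n",
    "\n",
    "\n",
    "def hit_rate_and_mrr(\n",
    "    retrieved: Dict[str, List[str]],\n",
    "    expected: Dict[str, str],\n",
    "    top_k: int = 5,\n",
    ") -> Dict[str, float]:\n",
    "    # retrieved: query id -> ranked node ids returned by the retriever\n",
    "    # expected:  query id -> the ground-truth node id for that query\n",
    "    hits = 0\n",
    "    reciprocal_ranks = []\n",
    "    for query_id, gold_id in expected.items():\n",
    "        top = retrieved.get(query_id, [])[:top_k]\n",
    "        if gold_id in top:\n",
    "            hits += 1\n",
    "            reciprocal_ranks.append(1.0 / (top.index(gold_id) + 1))  # ranks are 1-based\n",
    "        else:\n",
    "            reciprocal_ranks.append(0.0)  # not in the top-k: reciprocal rank is 0\n",
    "    n = len(expected)\n",
    "    return {'hit_rate': hits / n, 'mrr': sum(reciprocal_ranks) / n}\n",
    "\n",
    "\n",
    "retrieved = {\n",
    "    'q1': ['n1', 'n7', 'n3'],  # ground truth at rank 1\n",
    "    'q2': ['n4', 'n2', 'n9'],  # ground truth at rank 2\n",
    "    'q3': ['n5', 'n6', 'n8'],  # ground truth not retrieved\n",
    "}\n",
    "expected = {'q1': 'n1', 'q2': 'n2', 'q3': 'n0'}\n",
    "# expected: hit rate = 2/3, MRR = (1 + 1/2 + 0) / 3 = 0.5\n",
    "print(hit_rate_and_mrr(retrieved, expected, top_k=3))"
   ]
  },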
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7c7603a-6715-489f-a991-7efb597cb610",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.embeddings.openai import OpenAIEmbedding\n",
    "from llama_index.core import VectorStoreIndex\n",
    "from llama_index.core.schema import TextNode\n",
    "from tqdm.notebook import tqdm\n",
    "import pandas as pd\n",
    "\n",
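    "# `evaluate` and `display_results` come from eval_utils.py, a helper module next to this notebook\n",
    "from eval_utils import evaluate, display_results"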
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34e64a54-e97e-459b-a795-75ceb4eae44f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.0057108402252197266,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": 28,
       "postfix": null,
       "prefix": "Generating embeddings",
       "rate": null,
       "total": 395,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "161b1f18342a4e6d9d735f4f5944cfa5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating embeddings:   0%|          | 0/395 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████| 790/790 [03:03<00:00,  4.30it/s]\n"
     ]
    }
   ],
   "source": [
    "ada = OpenAIEmbedding()\n",
    "ada_val_results = evaluate(val_dataset, ada)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a31b8728-d510-48e5-980e-2f682d85da14",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>retrievers</th>\n",
       "      <th>hit_rate</th>\n",
       "      <th>mrr</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>ada</td>\n",
       "      <td>0.870886</td>\n",
       "      <td>0.72884</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  retrievers  hit_rate      mrr\n",
       "0        ada  0.870886  0.72884"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "display_results([\"ada\"], [ada_val_results])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4ce041a-ead0-4e10-a5f5-fbdfd20f2546",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004221916198730469,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": 28,
       "postfix": null,
       "prefix": "Generating embeddings",
       "rate": null,
       "total": 395,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ce0b815fb3a043a4b03482f9927c0a3a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating embeddings:   0%|          | 0/395 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████| 790/790 [00:23<00:00, 33.76it/s]\n"
     ]
    }
   ],
   "source": [
    "bge = \"local:BAAI/bge-small-en\"\n",
    "bge_val_results = evaluate(val_dataset, bge)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56697dc4-ec77-4acc-b306-eed45ec23eea",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>retrievers</th>\n",
       "      <th>hit_rate</th>\n",
       "      <th>mrr</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>bge</td>\n",
       "      <td>0.787342</td>\n",
       "      <td>0.643038</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  retrievers  hit_rate       mrr\n",
       "0        bge  0.787342  0.643038"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "display_results([\"bge\"], [bge_val_results])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17163fc1-703f-4377-a583-1f47ed0fb2ae",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.005866050720214844,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": 37,
       "postfix": null,
       "prefix": "Generating embeddings",
       "rate": null,
       "total": 395,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d99ce17b6ffe425e9ad493a91f073341",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating embeddings:   0%|          | 0/395 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 790/790 [00:21<00:00, 36.95it/s]\n"
     ]
    }
   ],
   "source": [
    "ft_val_results = evaluate(val_dataset, embed_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12442596-3964-4f3d-a6fe-df76e7e4b84a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>retrievers</th>\n",
       "      <th>hit_rate</th>\n",
       "      <th>mrr</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>ft</td>\n",
       "      <td>0.798734</td>\n",
       "      <td>0.662152</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  retrievers  hit_rate       mrr\n",
       "0         ft  0.798734  0.662152"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "display_results([\"ft\"], [ft_val_results])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce4b7478-9c98-4d1c-a7f0-208fcb64be0e",
   "metadata": {},
   "source": [
    "Here we show all the results concatenated together."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "840ba090-2880-4c41-ae2d-24697969be81",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>retrievers</th>\n",
       "      <th>hit_rate</th>\n",
       "      <th>mrr</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>ada</td>\n",
       "      <td>0.870886</td>\n",
       "      <td>0.730105</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>bge</td>\n",
       "      <td>0.787342</td>\n",
       "      <td>0.643038</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>ft</td>\n",
       "      <td>0.798734</td>\n",
       "      <td>0.662152</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  retrievers  hit_rate       mrr\n",
       "0        ada  0.870886  0.730105\n",
       "1        bge  0.787342  0.643038\n",
       "2         ft  0.798734  0.662152"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "display_results(\n",
    "    [\"ada\", \"bge\", \"ft\"], [ada_val_results, bge_val_results, ft_val_results]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8cef6d1d-cf4b-4103-842b-1230c9835922",
   "metadata": {},
   "source": [
    "## Fine-tune a Two-Layer Adapter\n",
    "\n",
    "Let's try fine-tuning a two-layer NN as well!\n",
    "\n",
    "It's a simple two-layer NN with a ReLU activation and a residual connection at the end.\n",
    "\n",
    "We train for 25 epochs (longer than the linear adapter) and save checkpoints every 100 steps."
   ]
  },
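  {
   "cell_type": "markdown",
   "id": "sketch-two-layer-md",
   "metadata": {},
   "source": [
    "To make the architecture concrete, here is a plain-Python sketch of a forward pass: hidden = ReLU(W1·x + b1), output = W2·hidden + b2, with the input added back when `add_residual=True`. It is not the library's `TwoLayerNN`. We assume the library's residual works this way, though it may also weight the MLP branch. `two_layer_forward` is hypothetical, and the tiny 2 → 3 → 2 weights are made up (the notebook itself uses 384 → 1024 → 384 for bge-small-en).\n",
    "\n",
    "With an all-zero second layer, the residual adapter returns its input unchanged, so the residual path lets the adapter reproduce the base embedding exactly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-two-layer-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List\n",
    "\n",
    "Vector = List[float]\n",
    "Matrix = List[List[float]]\n",
    "\n",
    "\n",
    "def affine(W: Matrix, x: Vector, b: Vector) -> Vector:\n",
    "    # computes W @ x + b\n",
    "    return [sum(w * xi for w, xi in zip(row, x)) + bi for row, bi in zip(W, b)]\n",
    "\n",
    "\n",
    "def two_layer_forward(\n",
    "    x: Vector, W1: Matrix, b1: Vector, W2: Matrix, b2: Vector, add_residual: bool = True\n",
    ") -> Vector:\n",
    "    hidden = [max(0.0, h) for h in affine(W1, x, b1)]  # ReLU activation\n",
    "    out = affine(W2, hidden, b2)\n",
    "    if add_residual:\n",
    "        out = [o + xi for o, xi in zip(out, x)]  # residual: add the input back\n",
    "    return out\n",
    "\n",
    "\n",
    "x = [0.5, -1.0]\n",
    "W1 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 2 -> 3 (hidden)\n",
    "b1 = [0.0, 0.0, 0.0]\n",
    "W2_zero = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]  # 3 -> 2, all zeros\n",
    "W2 = [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]  # 3 -> 2\n",
    "b2 = [0.0, 0.0]\n",
    "\n",
    "print(two_layer_forward(x, W1, b1, W2_zero, b2))  # MLP branch contributes nothing\n",
    "print(two_layer_forward(x, W1, b1, W2, b2))"
   ]
  },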
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5fdf226-001b-49bf-b1fb-f96678b1aa34",
   "metadata": {},
   "outputs": [],
   "source": [
    "# requires torch dependency\n",
    "from llama_index.embeddings.adapter.utils import TwoLayerNN\n",
    "\n",
    "from llama_index.finetuning import EmbeddingAdapterFinetuneEngine\n",
    "from llama_index.core.embeddings import resolve_embed_model\n",
    "from llama_index.embeddings.adapter import AdapterEmbeddingModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03cf055e-c14f-4faf-8636-f815ab7fedd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_embed_model = resolve_embed_model(\"local:BAAI/bge-small-en\")\n",
    "adapter_model = TwoLayerNN(\n",
    "    384,  # input dimension\n",
    "    1024,  # hidden dimension\n",
    "    384,  # output dimension\n",
    "    bias=True,\n",
    "    add_residual=True,\n",
    ")\n",
    "\n",
    "finetune_engine = EmbeddingAdapterFinetuneEngine(\n",
    "    train_dataset,\n",
    "    base_embed_model,\n",
    "    model_output_path=\"model5_output_test\",\n",
    "    model_checkpoint_path=\"model5_ck\",\n",
    "    adapter_model=adapter_model,\n",
    "    epochs=25,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ba3a65b-2eff-446f-a259-e9f9f26aa94e",
   "metadata": {},
   "outputs": [],
   "source": [
    "finetune_engine.finetune()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fabb32e2-ef12-4f1d-83a6-a465febfe7e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "embed_model_2layer = finetune_engine.get_finetuned_model(\n",
    "    adapter_cls=TwoLayerNN\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cadeb93a-b733-4e81-9301-5d63c80193df",
   "metadata": {},
   "source": [
    "### Evaluation Results\n",
    "\n",
    "We run the same evaluation as in the previous section to measure the hit rate and MRR of the two-layer model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9672bdc-cbfb-46ce-8a30-fad4e11a861c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the final fine-tuned model from the output path\n",
    "embed_model_2layer = AdapterEmbeddingModel(\n",
    "    base_embed_model,\n",
    "    \"model5_output_test\",\n",
    "    TwoLayerNN,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1aa65c4f-f5ec-470d-91bf-5ce7dc16b757",
   "metadata": {},
   "outputs": [],
   "source": [
    "from eval_utils import evaluate, display_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8edd857e-2d67-435c-a7c1-1ffdf57c163f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004235029220581055,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": 28,
       "postfix": null,
       "prefix": "Generating embeddings",
       "rate": null,
       "total": 395,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "51e7a3950ef644019e0c82add7402f29",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating embeddings:   0%|          | 0/395 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████| 790/790 [00:21<00:00, 36.93it/s]\n"
     ]
    }
   ],
   "source": [
    "ft_val_results_2layer = evaluate(val_dataset, embed_model_2layer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fc44340-0f40-49ba-b356-c51a7ee26f86",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>retrievers</th>\n",
       "      <th>hit_rate</th>\n",
       "      <th>mrr</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>ada</td>\n",
       "      <td>0.870886</td>\n",
       "      <td>0.728840</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>bge</td>\n",
       "      <td>0.787342</td>\n",
       "      <td>0.643038</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>ft_2layer</td>\n",
       "      <td>0.798734</td>\n",
       "      <td>0.662848</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  retrievers  hit_rate       mrr\n",
       "0        ada  0.870886  0.728840\n",
       "1        bge  0.787342  0.643038\n",
       "2  ft_2layer  0.798734  0.662848"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# comment out if you haven't run ada/bge yet\n",
    "display_results(\n",
    "    [\"ada\", \"bge\", \"ft_2layer\"],\n",
    "    [ada_val_results, bge_val_results, ft_val_results_2layer],\n",
    ")\n",
    "\n",
    "# uncomment if you just want to display the fine-tuned model's results\n",
    "# display_results([\"ft_2layer\"], [ft_val_results_2layer])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e7bbc01-63e8-4428-ad0a-8ddbbd35ca06",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the model from an intermediate checkpoint (step 900)\n",
    "embed_model_2layer_s900 = AdapterEmbeddingModel(\n",
    "    base_embed_model,\n",
    "    \"model5_ck/step_900\",\n",
    "    TwoLayerNN,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43c9d430-37d9-474f-b16e-27ff9a9c9552",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004389047622680664,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": 28,
       "postfix": null,
       "prefix": "Generating embeddings",
       "rate": null,
       "total": 395,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "608ef63f35564d79a180f6acc46ea868",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating embeddings:   0%|          | 0/395 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████| 790/790 [00:19<00:00, 40.57it/s]\n"
     ]
    }
   ],
   "source": [
    "ft_val_results_2layer_s900 = evaluate(val_dataset, embed_model_2layer_s900)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e910736-7918-47fa-8286-54d4944341f8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>retrievers</th>\n",
       "      <th>hit_rate</th>\n",
       "      <th>mrr</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>ada</td>\n",
       "      <td>0.870886</td>\n",
       "      <td>0.728840</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>bge</td>\n",
       "      <td>0.787342</td>\n",
       "      <td>0.643038</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>ft_2layer_s900</td>\n",
       "      <td>0.803797</td>\n",
       "      <td>0.667426</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       retrievers  hit_rate       mrr\n",
       "0             ada  0.870886  0.728840\n",
       "1             bge  0.787342  0.643038\n",
       "2  ft_2layer_s900  0.803797  0.667426"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# comment out if you haven't run ada/bge yet\n",
    "display_results(\n",
    "    [\"ada\", \"bge\", \"ft_2layer_s900\"],\n",
    "    [ada_val_results, bge_val_results, ft_val_results_2layer_s900],\n",
    ")\n",
    "\n",
    "# uncomment if you just want to display the fine-tuned model's results\n",
    "# display_results([\"ft_2layer_s900\"], [ft_val_results_2layer_s900])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7e3a5d9-91bc-47c3-8d92-66f8b37bccce",
   "metadata": {},
   "source": [
    "## Try Your Own Custom Model\n",
    "\n",
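    "You can define your own custom adapter here! Simply subclass `BaseAdapter`, which is a light wrapper around PyTorch's `nn.Module` class.\n",
    "\n",
    "You only need to implement `forward` (the transformation applied to each embedding) and `get_config_dict` (which returns the constructor arguments, so the adapter can be re-created when it is loaded from disk).\n",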
    "\n",
    "Just make sure you're familiar with writing `PyTorch` code :)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "daeb03b7-2bdd-4e0f-9b1c-eb4a8d801f10",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.embeddings.adapter_utils import BaseAdapter\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch import nn, Tensor\n",
    "from typing import Dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1552a664-2941-4f8f-9267-84e357404c83",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomNN(BaseAdapter):\n",
    "    \"\"\"Custom NN transformation.\n",
    "\n",
    "    A simplified copy of our TwoLayerNN (ReLU activation only), reproduced here for illustration.\n",
    "\n",
    "    Args:\n",
    "        in_features (int): Input dimension.\n",
    "        hidden_features (int): Hidden dimension.\n",
    "        out_features (int): Output dimension.\n",
    "        bias (bool): Whether the linear layers use a bias term. Defaults to False.\n",
    "        add_residual (bool): Whether to add a learned, zero-initialized residual\n",
    "            connection back to the input. Defaults to False.\n",
    "\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        in_features: int,\n",
    "        hidden_features: int,\n",
    "        out_features: int,\n",
    "        bias: bool = False,\n",
    "        add_residual: bool = False,\n",
    "    ) -> None:\n",
    "        super().__init__()\n",
    "        self.in_features = in_features\n",
    "        self.hidden_features = hidden_features\n",
    "        self.out_features = out_features\n",
    "        self.bias = bias\n",
    "\n",
    "        self.linear1 = nn.Linear(in_features, hidden_features, bias=bias)\n",
    "        self.linear2 = nn.Linear(hidden_features, out_features, bias=bias)\n",
    "        self._add_residual = add_residual\n",
    "        # if add_residual, the output is residual_weight * MLP(embed) + embed;\n",
    "        # residual_weight is initialized to 0, so training starts from the identity\n",
    "        self.residual_weight = nn.Parameter(torch.zeros(1))\n",
    "\n",
    "    def forward(self, embed: Tensor) -> Tensor:\n",
    "        \"\"\"Forward pass: a two-layer ReLU MLP with an optional residual connection.\n",
    "\n",
    "        Args:\n",
    "            embed (Tensor): Input embedding tensor.\n",
    "\n",
    "        \"\"\"\n",
    "        output1 = self.linear1(embed)\n",
    "        output1 = F.relu(output1)\n",
    "        output2 = self.linear2(output1)\n",
    "\n",
    "        if self._add_residual:\n",
    "            output2 = self.residual_weight * output2 + embed\n",
    "\n",
    "        return output2\n",
    "\n",
    "    def get_config_dict(self) -> Dict:\n",
    "        \"\"\"Get config dict.\"\"\"\n",
    "        return {\n",
    "            \"in_features\": self.in_features,\n",
    "            \"hidden_features\": self.hidden_features,\n",
    "            \"out_features\": self.out_features,\n",
    "            \"bias\": self.bias,\n",
    "            \"add_residual\": self._add_residual,\n",
    "        }"
   ]
  },
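  {
   "cell_type": "markdown",
   "id": "3c9e2f4a-7d1b-4e8a-b5c6-0f2a9d7e1b34",
   "metadata": {},
   "source": [
    "As an optional sanity check before training (a minimal sketch, assuming the 384-dimensional base embeddings used in the training cell below), you can instantiate `CustomNN` and pass a random batch through it. Because `residual_weight` is initialized to 0, an untrained adapter with `add_residual=True` should return its input unchanged, so fine-tuning starts from the base embeddings. The cell also rebuilds the adapter from `get_config_dict()`, which only works if the dict's keys match the `__init__` arguments. The names `check_adapter` and `rebuilt` are illustrative only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b5f1e7c-2a4d-4f9b-9c3e-6d1a0b7f2e58",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Sanity check (a sketch): assumes CustomNN from the cell above\n",
    "# and 384-dimensional base embeddings, as in the training cell below.\n",
    "check_adapter = CustomNN(384, 1024, 384, bias=True, add_residual=True)\n",
    "x = torch.randn(4, 384)  # a batch of 4 random stand-in embeddings\n",
    "\n",
    "with torch.no_grad():\n",
    "    y = check_adapter(x)\n",
    "\n",
    "# residual_weight is initialized to 0, so the untrained adapter is the identity\n",
    "assert y.shape == (4, 384)\n",
    "assert torch.allclose(y, x)\n",
    "\n",
    "# get_config_dict() must return valid constructor kwargs for reloading to work\n",
    "rebuilt = CustomNN(**check_adapter.get_config_dict())\n",
    "assert rebuilt.get_config_dict() == check_adapter.get_config_dict()"
   ]
  },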
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cfde782-c1fb-4c1c-9ab7-1f6140ae5aea",
   "metadata": {},
   "outputs": [],
   "source": [
    "custom_adapter = CustomNN(\n",
    "    384,  # input dimension\n",
    "    1024,  # hidden dimension\n",
    "    384,  # output dimension\n",
    "    bias=True,\n",
    "    add_residual=True,\n",
    ")\n",
    "\n",
    "finetune_engine = EmbeddingAdapterFinetuneEngine(\n",
    "    train_dataset,\n",
    "    base_embed_model,\n",
    "    model_output_path=\"custom_model_output\",\n",
    "    model_checkpoint_path=\"custom_model_ck\",\n",
    "    adapter_model=custom_adapter,\n",
    "    epochs=25,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afd808c4-aa17-4ad8-9bc7-3bc5666cc075",
   "metadata": {},
   "outputs": [],
   "source": [
    "finetune_engine.finetune()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f184f2b-c3e9-4f55-a1c2-96a09d9f9f49",
   "metadata": {},
   "outputs": [],
   "source": [
    "embed_model_custom = finetune_engine.get_finetuned_model(\n",
    "    adapter_cls=CustomNN\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b6d5e1b-f6ed-4183-96ae-ff4f959b4ab2",
   "metadata": {},
   "source": [
    "### Evaluation Results\n",
    "\n",
    "Run the same evaluation script as in the previous sections to measure hit rate and MRR."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8843d06b-ccb6-41b7-ab6d-c873d59540d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# [optional] load model manually\n",
    "# embed_model_custom = AdapterEmbeddingModel(\n",
    "#     base_embed_model,\n",
    "#     \"custom_model_ck/step_300\",\n",
    "#     CustomNN,\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "535d1c9b-fc32-4121-9fb3-00f4ba5c38aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "from eval_utils import evaluate, display_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1f91f85-b65f-4f8a-8fb1-41434f046cd7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004842042922973633,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": 28,
       "postfix": null,
       "prefix": "Generating embeddings",
       "rate": null,
       "total": 395,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dad3f2789d054672af07a00821a5d54f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating embeddings:   0%|          | 0/395 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████| 790/790 [00:20<00:00, 37.77it/s]\n"
     ]
    }
   ],
   "source": [
    "ft_val_results_custom = evaluate(val_dataset, embed_model_custom)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "977b8ffa-1ceb-46eb-b966-64a9bc686127",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>retrievers</th>\n",
       "      <th>hit_rate</th>\n",
       "      <th>mrr</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>ft_custom</td>\n",
       "      <td>0.789873</td>\n",
       "      <td>0.645127</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  retrievers  hit_rate       mrr\n",
       "0  ft_custom  0.789873  0.645127"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "display_results([\"ft_custom\"], [ft_val_results_custom])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama_index_v2",
   "language": "python",
   "name": "llama_index_v2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
